# Copyright (c) 2025 Microsoft Corporation.
# Licensed under the MIT License

"""
LiteLLM cache key generation.

Modeled after the fnllm cache key generation.
https://github.com/microsoft/essex-toolkit/blob/23d3077b65c0e8f1d89c397a2968fe570a25f790/python/fnllm/fnllm/caching/base.py#L50
"""

import hashlib
import inspect
import json
from typing import TYPE_CHECKING, Any

from pydantic import BaseModel

if TYPE_CHECKING:
    from graphrag.config.models.language_model_config import LanguageModelConfig


_CACHE_VERSION = 3
"""
If there's a breaking change in what we cache, we should increment this version number to invalidate existing caches.

fnllm was on cache version 2, and although we generate similar cache keys,
the objects stored in cache by fnllm and litellm are different.
Litellm model providers will not be able to reuse caches generated by fnllm,
so we start with version 3 for litellm.
"""


def get_cache_key(
    model_config: "LanguageModelConfig",
    prefix: str,
    messages: str | None = None,
    input: str | None = None,  # noqa: A002 - public keyword; kept for caller compatibility
    **kwargs: Any,
) -> str:
    """Generate a cache key based on the model configuration and input arguments.

    Modeled after the fnllm cache key generation.
    https://github.com/microsoft/essex-toolkit/blob/23d3077b65c0e8f1d89c397a2968fe570a25f790/python/fnllm/fnllm/caching/base.py#L50

    Args
    ----
        model_config: The configuration of the language model.
        prefix: A prefix for the cache key.
        messages: Serialized chat messages (mutually exclusive with `input`).
        input: Embedding input text (mutually exclusive with `messages`).
        **kwargs: Additional model input parameters. If a truthy `name` value
            is present, it is appended to the prefix.

    Returns
    -------
        A key of the form `{prefix}_{data_hash}_v{version}`.

    Raises
    ------
        ValueError: If both or neither of `messages` and `input` are provided.
    """
    # Fail fast on invalid argument combinations before computing parameters.
    if messages is not None and input is not None:
        msg = "Only one of 'messages' or 'input' should be provided."
        raise ValueError(msg)
    if messages is None and input is None:
        msg = "Either 'messages' or 'input' must be provided."
        raise ValueError(msg)

    cache_key: dict[str, Any] = {
        "parameters": _get_parameters(model_config, **kwargs),
    }
    if messages is not None:
        cache_key["messages"] = messages
    else:
        cache_key["input"] = input

    # sort_keys makes the JSON serialization (and thus the hash) deterministic.
    data_hash = _hash(json.dumps(cache_key, sort_keys=True))

    # An optional operation name distinguishes otherwise-identical requests.
    name = kwargs.get("name")
    if name:
        prefix += f"_{name}"

    return f"{prefix}_{data_hash}_v{_CACHE_VERSION}"


def _get_parameters(
    model_config: "LanguageModelConfig",
    **kwargs: Any,
) -> dict[str, Any]:
    """Pluck out the parameters that define a cache key.

    Use the same parameters as fnllm except request timeout.
    - embeddings: https://github.com/microsoft/essex-toolkit/blob/main/python/fnllm/fnllm/openai/types/embeddings/parameters.py#L12
    - chat: https://github.com/microsoft/essex-toolkit/blob/main/python/fnllm/fnllm/openai/types/chat/parameters.py#L25

    Args
    ____
        model_config: The configuration of the language model.
        **kwargs: Additional model input parameters.

    Returns
    -------
        dict[str, Any]: A dictionary of parameters that define the cache key.
    """
    parameters = {
        "model": model_config.deployment_name or model_config.model,
        "frequency_penalty": model_config.frequency_penalty,
        "max_tokens": model_config.max_tokens,
        "max_completion_tokens": model_config.max_completion_tokens,
        "n": model_config.n,
        "presence_penalty": model_config.presence_penalty,
        "temperature": model_config.temperature,
        "top_p": model_config.top_p,
        "reasoning_effort": model_config.reasoning_effort,
    }
    keys_to_cache = [
        "function_call",
        "functions",
        "logit_bias",
        "logprobs",
        "parallel_tool_calls",
        "seed",
        "service_tier",
        "stop",
        "tool_choice",
        "tools",
        "top_logprobs",
        "user",
        "dimensions",
        "encoding_format",
    ]
    parameters.update({key: kwargs.get(key) for key in keys_to_cache if key in kwargs})

    response_format = kwargs.get("response_format")
    if inspect.isclass(response_format) and issubclass(response_format, BaseModel):
        parameters["response_format"] = str(response_format)
    elif response_format is not None:
        parameters["response_format"] = response_format

    return parameters


def _hash(input: str) -> str:
    """Generate a hash for the input string."""
    return hashlib.sha256(input.encode()).hexdigest()
